/*
  Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

  This file is part of Sunrise.

  Sunrise is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3 of the License, or
  (at your option) any later version.

  Sunrise is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.
  
*/

#include<iostream>
#include "CCfits/CCfits"
#include "counter.h"
#include "polychromatic_grid.h"
#include "optical.h"
#include "scatterer.h"
#include "constants.h"
#include "dust_model.h"
#include "emission.h"
#include "emission-fits.h"
#include "full_sed_emergence.h"
#include "full_sed_grid.h"
#include "emergence-fits.h"
#include "mcrx-units.h"
#include "grid_factory.h"
#include "grid-fits.h"
#include "shoot.h"
#include "fits-utilities.h"
#include "boost/lexical_cast.hpp"
#include "boost/shared_ptr.hpp"
#include "boost/scoped_ptr.hpp"
#include "density_generator.h"
#include "preferences.h"
#include "model.h"
#include "grain_model_hack.h"
#include "p04_grain_model.h"
#include "wd01_grain_model.h"
#include "mlt_xfer.h"
#include "imops.h"
#include "boost/program_options.hpp"

using namespace mcrx;
using namespace CCfits;
using namespace std; 
using namespace boost;
using namespace blitz;

using boost::lexical_cast;
namespace po = boost::program_options;

typedef local_random T_rng_policy;
typedef scatterer<polychromatic_scatterer_policy, T_rng_policy> T_scatterer;
typedef T_scatterer::T_lambda T_lambda;
typedef T_scatterer::T_biaser T_biaser;
typedef dust_model<T_scatterer, cumulative_sampling, T_rng_policy> T_dust_model;
typedef full_sed_emergence T_emergence;

typedef full_sed_grid<adaptive_grid> T_grid;
typedef emission<polychromatic_policy, T_rng_policy> T_emission;


/** Append one camera, pointed at the origin, to the parallel camera
    arrays pos/dir/up/fov.

    \param theta    inclination in degrees (converted to radians here)
    \param phi      azimuth, passed straight to sin/cos (i.e. radians)
    \param cd       distance of the camera from the origin
    \param fov_fact field of view in units of sqrt(2)*gridsize/cd
    \param gridsize half-extent of the grid, used to scale the field of view
    \param pos,dir,up,fov  arrays that each receive one new entry
*/
void add_cam(T_float theta, T_float phi, T_float cd, 
	     T_float fov_fact, T_float gridsize,
	     vector<vec3d>& pos, 
	     vector<vec3d>& dir,
	     vector<vec3d>& up, 
	     vector<T_float>& fov)
{
  const T_float theta_rad = theta*constants::pi/180.;

  // unit vector pointing from the origin towards the camera
  const vec3d outward(sin(theta_rad)*cos(phi),
		      sin(theta_rad)*sin(phi),
		      cos(theta_rad));

  // the camera sits at distance cd along "outward" and looks back at
  // the origin; "up" is always the z axis
  dir.push_back(-outward);
  pos.push_back(outward*cd);
  up.push_back(vec3d(0,0,1));
  fov.push_back(fov_fact*sqrt(2.)*gridsize/cd);
}


/** Grid factory for the Pascucci et al. (2004) 2D disk benchmark.

    The dust density (see calc_rho) is that of a flared disk,

      rho(r,z) = rho0 * (rd/r) * exp(-pi/4 * (z/h(r))^2),
      h(r)     = zd * (r/rd)^1.125,

    for ri < r < ro and zero elsewhere, where r is the cylindrical
    radius sqrt(x^2+y^2). The geometry is fixed: ri=1, ro=1000,
    rd=ro/2, zd=ro/8 (in the grid's length units).

    Cells are refined (up to level maxl) where there is dust and the
    cell is large compared with its cylindrical radius or straddles the
    inner rim, and otherwise while the cell column density exceeds
    n_tol for every density component.
*/
class pascucci_factory: public grid_factory<T_grid::T_data> {
public:
  const T_densities rho0_, n_tol_;
  const T_float ri, ro, rd, zd, r_tol_;
  const int maxl, n_lambda;
  
  /** \param rho0  density normalization (one entry per density component)
      \param n_tol column-density refinement tolerance per component
      \param r_tol maximum ratio of cell diagonal to cylindrical radius
      \param ml    maximum refinement level
      \param nl    length of the T_lambda array given to each cell's data */
  pascucci_factory (const T_densities& rho0, const T_densities& n_tol,
		    T_float r_tol,
		    int ml, int nl):
    // Listed in declaration order, which is the order C++ initializes
    // members in regardless of this list (rd and zd depend on ro).
    rho0_(rho0), n_tol_(n_tol),
    ri(1), ro(1000), rd (ro/2), zd(ro/8), r_tol_(r_tol),
    maxl(ml), n_lambda (nl) {};

  /// Analytic disk density at point p (zero outside ri < r < ro).
  T_densities calc_rho(const vec3d& p) {
    T_densities rho;
    resize_like(rho, rho0_);
    rho = 0;
    const T_float r = sqrt(p[0]*p[0] + p[1]*p[1]); 
    const T_float z = p[2];
    if ((r < ro) && (r > ri)) {
      // we are in the allowed range of radii, calculate the density
      const T_float h=zd*pow(r/rd,1.125);
      const T_float f2=exp(-constants::pi/4*pow(z/h,2));
      rho = rho0_ *(rd/r)*f2;
    }
    return rho;
  };

  /// Decide whether cell c (at refinement level "level") is subdivided.
  virtual bool refine_cell_p (const T_cell& c, int level) {
    if (level>maxl) 
      return false;

    // (simone uses maxlevel=10+log(r/rmin))
    T_densities rhoav(effective_rho(c));

    const T_float size2 = dot(c.getsize(),c.getsize()); // squared diagonal
    vec3d cylr=c.getcenter();
    cylr[2]=0;
    const T_float r2 = dot(cylr,cylr);   // squared cylindrical radius

    // Where there is dust, refine if the cell is large compared with its
    // cylindrical radius (diagonal/r > r_tol_, written without a division
    // so r=0 is harmless), or if the inner rim ri lies within half a cell
    // diagonal of the cell center.
    if (blitz::any(rhoav>0) &&
	( (size2 > r_tol_*r_tol_*r2) ||
	  (2*fabs(sqrt(r2)-ri) < sqrt(size2)) ))
      return true;

    // Otherwise refine while the column density across the cell exceeds
    // n_tol_ for every density component.
    return condition_all (rhoav*rhoav*size2 > n_tol_*n_tol_);
  };

  /// Cells are never merged back together.
  virtual bool unify_cell_p (T_grid::const_local_iterator begin,
			     T_grid::const_local_iterator end) {
    return false;
    /* Disabled opacity-variance criterion, kept for reference (it
       refers to tolerances tol/tabs that this class does not define):
    vec3d center = begin->getcenter();
    T_float vol = begin->volume();
    T_densities s ((begin++)->data()->get_absorber().opacity());
    T_densities ssq (s*s);
    int n=1;
    while (begin != end) {
      center += begin->getcenter();
      vol += begin->volume();
      T_densities o = (begin++)->data()->get_absorber().opacity();
      s+=o;
      ssq+=o*o;
      n++;
    }

    const T_densities mean (s/n);
    const T_densities stdev2 (ssq/n- mean*mean);
    const T_float r2=dot(center,center)/(n*n);
    return condition_all ((mean==0) || ((stdev2/(mean*mean))<tol) || 
			  (stdev2<r2*r2/(vol*vol)*tabs*tabs));
    */
  };

  /// Cell data: the effective density, zero velocity, n_lambda wavelengths.
  virtual T_data get_data (const T_cell& c) {
    T_densities rhoav(effective_rho(c));
    return T_data ( rhoav, vec3d(0,0,0), T_lambda (n_lambda) );
  }
    
  virtual int n_threads () {return 1;};
  virtual int work_chunk_levels () {return 5;};
  virtual int estimated_levels_needed () {return 0;};

private:
  /** Effective density of cell c from 9 samples of calc_rho: the center
      and the 8 corners. With sz the cell diagonal, the samples are
      combined as -log(mean(exp(-sz*rho_i)))/sz rather than averaged
      arithmetically; by Jensen's inequality the result never exceeds the
      arithmetic mean of the samples. Shared by refine_cell_p and
      get_data so both use the same density. (Arithmetic-mean
      alternative: (calc_rho(center) + sum of calc_rho(corners))/9.) */
  T_densities effective_rho(const T_cell& c) {
    const vec3d min = c.getmin();
    const vec3d max = c.getmax();
    const T_float sz=sqrt(dot(c.getsize(),c.getsize()));
    T_densities rhoav
      (-log((exp(-sz*calc_rho(c.getcenter()))+
	     exp(-sz*calc_rho(min))+
	     exp(-sz*calc_rho(vec3d(min[0],min[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(min[0],max[1],min[2])))+
	     exp(-sz*calc_rho(vec3d(min[0],max[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],min[1],min[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],min[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],max[1],min[2])))+
	     exp(-sz*calc_rho(max)))/9)/sz) ;
    return rhoav;
  }
};


/** Grid factory for a uniform-density sphere centered on the origin.

    Cells whose center lies within rmax of the origin get density rho;
    all others get zero. A cell is refined while its diagonal times its
    density exceeds n_tol for every density component. There is no
    refinement-level cap in this class, so n_tol should be positive
    wherever rho is nonzero (unless the grid_factory base imposes its
    own limit -- not visible here).
*/
class uniform_factory: public grid_factory<T_grid::T_data> {
public:
  const T_densities rho_, n_tol_;
  const T_float rmax_;
  const int n_lambda;

  /** \param rho   density inside the sphere
      \param n_tol column-density refinement tolerance per component
      \param nl    length of the T_lambda array given to each cell's data
      \param rmax  sphere radius */
  uniform_factory (const T_densities& rho, const T_densities& n_tol, 
		   int nl, T_float rmax):
    // declaration order, which is the actual initialization order
    rho_ (rho), n_tol_(n_tol), rmax_(rmax), n_lambda (nl) {};

  /** Density of cell c: a copy of rho_ if the cell center is inside
      radius rmax_, zeros otherwise. rho_ is explicitly copied because
      copy-constructing a blitz Array shares its storage; returning rho_
      itself would give every cell an alias of this factory's const
      density array. */
  T_densities get_rho(const T_cell& c) {
    if (dot(c.getcenter(), c.getcenter()) < rmax_*rmax_)
      return rho_.copy();
    T_densities zero(rho_*0.);
    return zero;
  };
  
  /// Refine while cell column density exceeds n_tol_ for all components.
  virtual bool refine_cell_p (const T_cell& c, int level) {
    T_densities rho = get_rho(c);
    T_densities n2 (dot(c.getsize(), c.getsize())*rho*rho);
    return all( n2 > (n_tol_*n_tol_));
  };

  /// Cells are never merged back together.
  virtual bool unify_cell_p (T_grid::const_local_iterator begin,
			     T_grid::const_local_iterator end) {
    return false;
  };

  /// Cell data: the cell density, zero velocity, n_lambda wavelengths.
  virtual T_data get_data (const T_cell& c) {
    return T_data(get_rho(c), vec3d(0,0,0), T_lambda(n_lambda));
  }
    
  virtual int n_threads () {return 1;};
  virtual int work_chunk_levels () {return 5;};
  virtual int estimated_levels_needed () {return 0;};
};

/** Build the adaptive Pascucci disk grid and report its total dust mass.

    The density normalization rho0 is chosen so that the midplane column
    from ri to ro corresponds to the "tau" option; refinement is
    controlled by the "maxlevel", "r_tol" and "tau_tol" options.

    \param opt   parsed command-line options (taken by const reference;
                 previously the whole option map was copied)
    \param sv    grain models; sv[0]'s opacity converts optical depth
                 to density
    \param units unit map handed to the full_sed_grid
    \return the newly built grid
*/
shared_ptr<T_grid> make_grid(const po::variables_map& opt, 
			     const T_dust_model::T_scatterer_vector& sv,
			     const T_unit_map& units)
{
  const vec3i grid_n (20,20,40);
  const int maxlevel = opt["maxlevel"].as<int>();
  const vec3d grid_extent(1000,1000,1000);

  // opacity for the WD01 dust = 2.41e-6 kpc^2/Msun (=1.025e11 au^2/Msun)
  // for pascucci grid, n=rho0*rd*6.9078
  // which implies tau = 3.54e14 au^3/Msun * rho0

  // grid refinement tolerances
  const T_float r_tol =opt["r_tol"].as<T_float>();
  const T_float tau = opt["tau"].as<T_float>();
  const T_float tau_tol = opt["tau_tol"].as<T_float>();

  // The midplane column from ri to ro is rho0*rd*ln(ro/ri). These two
  // numbers must match the geometry hard-coded in pascucci_factory
  // (rd = 500, ro/ri = 1000).
  const T_float disk_rd = 500;
  const T_float ln_ro_over_ri = 6.9078;

  // Convert optical depth to density. NOTE(review): assumes opacity() is
  // indexed by the wavelengths set up in main (0.4, 0.55, 0.7 um), so
  // index 1 is 550 nm -- confirm if the wavelength list changes.
  const T_float unit_opacity_rho = 1.0/sv[0]->opacity()(1);
  T_densities rho(1);
  rho = tau*unit_opacity_rho/(disk_rd*ln_ro_over_ri);
  T_densities n_tol(1); n_tol = tau_tol*unit_opacity_rho;
  cout << "Unit opacity density: " << unit_opacity_rho <<  ", n_tol: " << n_tol(0) << endl;

  // build grid
  cout << "building grid" << endl;
  pascucci_factory factory (rho,n_tol, r_tol, maxlevel, 0);

  // first create the adaptive grid
  boost::shared_ptr<adaptive_grid<T_grid::T_data> > gg
    (new adaptive_grid<T_grid::T_data> (-grid_extent, grid_extent,
				grid_n, factory));

  shared_ptr<T_grid> gr(new T_grid (gg, units));

  // calculate total dust mass
  T_densities mtot(1);
  mtot=0;
  for(T_grid::const_iterator c=gr->begin(); c!=gr->end(); ++c)
    mtot+=c->data()->densities()*c->volume();
  cout << "Total dust mass: " << mtot(0) << endl;

  return gr;
}


void pascucci_test (po::variables_map opt, 
		    shared_ptr<T_grid> gr, T_dust_model& model,
		    shared_ptr<T_emission> em, const T_unit_map& units,
		    const vector<vec3d>& pos,
		    const vector<vec3d>& dir,
		    const vector<vec3d>& up,
		    const vector<T_float>& fov,
		    const string& filename
		    )
{
  const bool use_metro = opt["use_metro"].as<bool>();
  const bool use_bidir = opt["use_bidir"].as<bool>();

  const int n_threads = opt["n_threads"].as<int>();
  const long n_rays = opt["n_rays"].as<long>(); 
  const long n_mutations =opt["n_mut"].as<long>();
  const long n_metropolis =opt["n_metro"].as<long>();
  const long n_seeds =opt["n_seeds"].as<long>();
  const int n_forced= use_metro ? 0 : opt["n_forced"].as<int>();
  const T_float I_rr=opt["i_rr"].as<T_float>();
  const int vis_ref_lambda = opt["vis_ref"].as<int>();
  const string output_file(opt["output-file"].as<string>());
  const int npix=100;


  cout << "setting up emergence" << endl;
  const int max_n_track= use_bidir ? 4 :-1;
  auto_ptr<T_emergence> e;
  if(max_n_track>0) {
    const int n_cams=(max_n_track+1)*(max_n_track+2)/2+1;
    vector<vec3d> ppos(n_cams, pos[0]);
    vector<vec3d> ddir(n_cams, dir[0]);
    vector<vec3d> uup(n_cams, up[0]);
    vector<T_float> ffov(n_cams, fov[0]);
    e.reset(new T_emergence(ppos,ddir,uup,ffov, 300));
  }
  else
    e.reset(new T_emergence(pos,dir,up,fov, 300));

  cout << "shooting" << endl;

  std::vector<T_rng::T_state> states= generate_random_states (n_threads);
  T_rng_policy rng;

  //i_split is huge with metro because we can't have split rays go
  //split themselves
  mlt_xfer<T_dust_model, T_grid, T_rng_policy> x (*gr, *em, *e, model, rng,  I_rr, 0,n_forced,use_metro?1e300:10,false, false);
  long n_rays_actual = 0;
  dummy_terminator t;
  T_biaser b(vis_ref_lambda);
  T_float mean_weight=1;

  if(use_metro||use_bidir) {
    x.rng().get_generator().seed(opt["seed"].as<int>());
    if(use_metro) {
      mean_weight=x.metropolis_driver(n_seeds, n_metropolis, n_mutations, b);
      n_rays_actual=n_metropolis*n_mutations;
    }
    else {
      x.max_n_track_ = max_n_track;
      const int ns_max = blitz::huge(int());
      const int nc_max = blitz::huge(int());
      shoot (x, bidir_shooter<T_biaser> (b, ns_max, nc_max), t,
	     states, n_rays, n_rays_actual, n_threads, true, n_threads, 0, "");
    }
  } 
  else {
    shoot (x, scatter_shooter<T_biaser> (b), t, states, n_rays,
	   n_rays_actual, n_threads, true, n_threads, 0, "");
    mean_weight=1;
  }

  T_float normalization = mean_weight/n_rays_actual;
  cout << "Total per-ray camera signal " << sum((*e->begin())->get_image())*normalization << endl;
  FITS output(filename, Write);
  e->write_images(output,normalization,units, "-STAR", false, false);

  return;
}


int main (int argc, char** argv)
{
  // Declare the supported options.
  po::options_description desc("Pascucci et al 04 benchmark. Allowed options");
  desc.add_options()
    ("output-file", po::value<string>(), "output file name")
    ("n_seeds", po::value<long>()->default_value(10000), "number of Metropolis seed paths")
    ("n_metro", po::value<long>()->default_value(100), "number of Metropolis runs")
    ("n_mut", po::value<long>()->default_value(1000000), "number of mutations per metropolis run")
    ("n_rays", po::value<long>()->default_value(1000000), "number of rays for normal runs")
    ("n_threads", po::value<int>()->default_value(8), "number of threads")
    ("maxlevel", po::value<int>()->default_value(13), 
     "max grid refinement level")
    ("n_forced", po::value<int>()->default_value(1), 
     "number of forced scatterings")
    ("vis_ref", po::value<int>()->default_value(1),
     "stellar radiation reference wavelength")
    ("i_rr", po::value<T_float>()->default_value(1e-1),
     "intensity where Russian Roulette starts")
    ("r_tol", po::value<T_float>()->default_value(0.2),
     "cell radial extent tolerance")
    ("tau", po::value<T_float>()->default_value(100),
     "edge-on V-band optical depth of disk")
    ("tau_tol", po::value<T_float>()->default_value(1),
     "cell optical depth refinement tolerance")
    ("seed", po::value<int>()->default_value(42), "random number seed")
    ("use_metro", po::value<bool>()->default_value(true), "Use MLT")
    ("use_bidir", po::value<bool>()->default_value(false), "Use bidirectional tracing")
    ("help", "print this page")
    ;

  po::positional_options_description p;
  p.add("output-file", 1);

  po::variables_map opt;
  po::store(po::command_line_parser(argc, argv).
          options(desc).positional(p).run(), opt);
  po::notify(opt);

  if (opt.count("help")) {
    cout << desc << "\n";
    return 1;
  }
  if(!opt.count("output-file")) {
    cout << "Must specify output file name on command line" << endl;
    return 1;
  }

  cout << "Running Pascucci et al 04 benchmark with parameters:\n";
  cout << "output-file = " << opt["output-file"].as<string>() << '\n'
       << "n_seeds = " << opt["n_seeds"].as<long>() << '\n'
       << "n_metro = " << opt["n_metro"].as<long>() << '\n'
       << "n_mut = " << opt["n_mut"].as<long>() << '\n'
       << "n_rays = " << opt["n_rays"].as<long>() << '\n'
       << "n_threads = " << opt["n_threads"].as<int>() << '\n'
       << "maxlevel = " << opt["maxlevel"].as<int>() << '\n'
       << "n_forced = " << opt["n_forced"].as<int>() << '\n'
       << "vis_ref = " << opt["vis_ref"].as<int>() << '\n'
       << "i_rr = " << opt["i_rr"].as<T_float>() << '\n'
       << "r_tol = " << opt["r_tol"].as<T_float>() << '\n'
       << "i_rr = " << opt["i_rr"].as<T_float>() << '\n'
       << "tau = " << opt["tau"].as<T_float>() << '\n'
       << "tau_tol = " << opt["tau_tol"].as<T_float>() << '\n'
       << "seed = " << opt["seed"].as<int>() << '\n'
       << "use_metro = " << opt["use_metro"].as<bool>() << '\n'
       << "use_bidir = " << opt["use_bidir"].as<bool>() << '\n'
       << "\n\n";

  const T_float cameradist = 50000;

  const string basename=opt["output-file"].as<string>();
  int imnum=0;

  T_unit_map units;
  units ["length"] = "au";
  units ["mass"] = "Msun";
  units ["luminosity"]= "W";
  units ["wavelength"] = "m";
  units ["L_lambda"] = "W/m";

  const int seed=opt["seed"].as<int>();
  mcrx::seed(seed);

  //need to communicate number of threads to the grain objects.
  Preferences prefs;
  prefs.setValue("n_threads", opt["n_threads"].as<int>());
  prefs.setValue("use_grain_temp_lookup", true);

  // read wavelengths
  vector<T_float> lambdas;
  lambdas.push_back(0.4e-6);
  lambdas.push_back(0.55e-6);
  lambdas.push_back(0.7e-6);
  array_1 lambda(&lambdas[0], TinyVector<int,1>(lambdas.size()), neverDeleteData);

  // load grain model
  T_dust_model::T_scatterer_vector sv;
  sv.push_back(shared_ptr<T_scatterer> 
	       (new p04_grain_model<polychromatic_scatterer_policy, mcrx_rng_policy>(word_expand("$HOME/dust_data/crosssections")[0], prefs, units)));
  T_dust_model model (sv);
  model.set_wavelength(lambda); // this also resamples the grain_model objects

  shared_ptr<T_grid> gr(make_grid(opt, sv, units));

  cout << "\nSetting up emission" << endl;
  blackbody bb(5800,3.91e26);
  shared_ptr<T_emission> em(new pointsource_emission<polychromatic_policy, T_rng_policy> (vec3d (.01, .01, -.01), bb.emission(lambda)));

  cout << bb.emission(lambda);
  const int lambda550 = lower_bound(lambdas.begin(), lambdas.end(),
				    0.55e-6) - lambdas.begin();
  cout << "550nm is lambda entry " << lambda550 << endl;

  // define cameras
  std::vector<vec3d> pos,dir,up;
  std::vector<T_float> fov;
  const T_float gridsize=gr->grid().getmax()[0];
  /*
  add_cam(12.5, .393, cameradist, 2., gridsize, pos, dir, up, fov);
  add_cam(42.5, .393, cameradist, 2., gridsize, pos, dir, up, fov);
  add_cam(65.5, .393, cameradist, 2., gridsize, pos, dir, up, fov);
  add_cam(77.5, .393, cameradist, 2., gridsize, pos, dir, up, fov);
  add_cam(88.5, .393, cameradist, 2., gridsize, pos, dir, up, fov);
  add_cam(88.5, .393, cameradist, 1., gridsize, pos, dir, up, fov);
  add_cam(88.5, .393, cameradist, .5, gridsize, pos, dir, up, fov);
  */
  add_cam(88.5, .393, cameradist, .25, gridsize, pos, dir, up, fov);
  if(opt["use_metro"].as<bool>() ||
     opt["use_bidir"].as<bool>() ) {
    const int n=pos.size();
    vector<vec3d>::const_iterator pi=pos.begin();
    vector<vec3d>::const_iterator di=dir.begin();
    vector<vec3d>::const_iterator ui=up.begin();
    vector<T_float>::const_iterator fi=fov.begin();
    for(int i=0;i<n;i++) {
      cout << "\n\n\nStarting image " << imnum << endl;
      
      pascucci_test(opt, gr, model, em, units, vector<vec3d>(pi++,pi),
		    vector<vec3d>(di++,di),
		    vector<vec3d>(ui++,ui),
		    vector<T_float>(fi++,fi),
		    basename+lexical_cast<string>(i)+".fits");
    }
  }
  else {
    // in this case we run all cameras at once
    pascucci_test(opt, gr, model, em, units, pos, dir, up, fov,
		    basename+"-all.fits");
  }
}
